# Copyright 2024 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Colocated Python function API implementation."""

from __future__ import annotations

from collections.abc import Callable, Sequence
import dataclasses
import functools
import inspect
import random
import threading
from typing import Any

import jax
from jax._src import api
from jax._src import tree_util
from jax._src import util
from jax._src.interpreters import pxla
from jax._src.lib import xla_client as xc
from jax._src.traceback_util import api_boundary
from jax._src.util import wraps
from jax.experimental.colocated_python import func_backend
from jax.experimental.colocated_python.serialization import (
    _deserialize,
    _deserialize_specs,
    _make_specs_for_serialized_specs,
    _serialize,
    _serialize_specs,
)
from jax.extend.backend import register_backend_cache as jax_register_backend_cache
from jax.extend.ifrt_programs import ifrt_programs

ShapeDtypeStructTree = Any  # PyTree[api.ShapeDtypeStruct]


@dataclasses.dataclass(frozen=True, slots=True)
class FunctionInfo:
  """User function wrapped by colocated_python."""

  fun: Callable[..., Any]
  fun_sourceinfo: str | None
  fun_signature: inspect.Signature | None


@dataclasses.dataclass(frozen=True, slots=True)
class Specialization:
  """Specialization for a colocated_python function."""

  in_specs_treedef: tree_util.PyTreeDef | None = None
  in_specs_leaves: tuple[api.ShapeDtypeStruct, ...] | None = None
  out_specs_fn: Callable[..., ShapeDtypeStructTree] | None = None
  out_specs_treedef: tree_util.PyTreeDef | None = None
  out_specs_leaves: tuple[api.ShapeDtypeStruct, ...] | None = None
  devices: xc.DeviceList | None = None

  def update(
      self,
      *,
      in_specs_treedef: tree_util.PyTreeDef | None = None,
      in_specs_leaves: tuple[api.ShapeDtypeStruct, ...] | None = None,
      out_specs_fn: Callable[..., ShapeDtypeStructTree] | None = None,
      out_specs_treedef: tree_util.PyTreeDef | None = None,
      out_specs_leaves: tuple[api.ShapeDtypeStruct, ...] | None = None,
      devices: Sequence[jax.Device] | xc.DeviceList | None = None,
  ):
    """Creates a new specialization with overrides."""
    if in_specs_treedef is None:
      in_specs_treedef = self.in_specs_treedef
    elif self.in_specs_treedef is not None:
      raise ValueError("in_specs already specified")
    if in_specs_leaves is None:
      in_specs_leaves = self.in_specs_leaves
    elif self.in_specs_leaves is not None:
      raise ValueError("in_specs already specified")

    if out_specs_fn is None:
      out_specs_fn = self.out_specs_fn
    elif self.out_specs_fn is not None:
      raise ValueError("out_specs_fn already specified")

    if out_specs_treedef is None:
      out_specs_treedef = self.out_specs_treedef
    elif self.out_specs_treedef is not None:
      raise ValueError("out_specs already specified")
    if out_specs_leaves is None:
      out_specs_leaves = self.out_specs_leaves
    elif self.out_specs_leaves is not None:
      raise ValueError("out_specs already specified")

    if devices is None:
      devices = self.devices
    elif self.devices is not None:
      raise ValueError("devices already specified")
    elif not isinstance(devices, xc.DeviceList):
      devices = xc.DeviceList(tuple(devices))

    return Specialization(
        in_specs_treedef,
        in_specs_leaves,
        out_specs_fn,
        out_specs_treedef,
        out_specs_leaves,
        devices,
    )


def _get_spec(x: Any) -> api.ShapeDtypeStruct:
  """Extracts a spec for a value, which must be a JAX Array."""
  # TODO(hyeontaek): Allow Python values and automatically apply `shard_arg`
  # with a suitable sharding and layout.
  if not isinstance(x, jax.Array):
    raise ValueError(
        "colocated_python only supports jax.Array as input and output, but got"
        f" {type(x)}."
    )
  return api.ShapeDtypeStruct(shape=x.shape, dtype=x.dtype, sharding=x.sharding)


def _infer_devices_from_args(args: Sequence[Any]) -> xc.DeviceList | None:
  """Returns a representative device list from function call arguments."""
  device_list_set: set[xc.DeviceList] = set()
  for x in args:
    sharding = getattr(x, "sharding", None)
    if sharding is not None:
      device_list_set.add(sharding._internal_device_list)
  if not device_list_set:
    return None
  if len(device_list_set) != 1:
    raise ValueError(
        "All arguments must use the same device list, but got"
        f" multiple device lists: {device_list_set}."
    )
  return device_list_set.pop()


def _compile_to_executable(
    name: str,
    fun: Callable[..., Any],
    in_specs_treedef: tree_util.PyTreeDef,
    in_specs_leaves: tuple[api.ShapeDtypeStruct, ...],
    out_specs_treedef: tree_util.PyTreeDef,
    out_specs_leaves: tuple[api.ShapeDtypeStruct, ...],
    devices: xc.DeviceList,
) -> Callable[..., Any]:
  """Compiles a Python function into a runtime executable."""
  fun_and_specialization = (
      fun,
      in_specs_treedef,
      in_specs_leaves,
      out_specs_treedef,
      out_specs_leaves,
      devices,
  )
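  # Pickle the user function together with its specialization; the colocated
  # backend deserializes it and runs it on the given devices.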
  pickled_function = _serialize(fun_and_specialization)
  program = ifrt_programs.make_colocated_python_program(
      name, pickled_function, devices, in_specs_leaves, out_specs_leaves
  )
  ifrt_client = devices[0].client
  out_sdss = tuple(
      jax.core.ShapedArray(sds.shape, sds.dtype) for sds in out_specs_leaves
  )
  out_shardings = tuple(sds.sharding for sds in out_specs_leaves)
  try:
    compile_options = ifrt_programs.make_colocated_python_compile_options()
    loaded_executable = ifrt_client.compile_ifrt_program(
        program, compile_options
    )
    out_handlers = pxla.global_avals_to_results_handler(
        out_sdss, out_shardings, committed=True  # type: ignore
    ).handlers

    def call(*args, **kwargs):
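      # Flatten (args, kwargs), run the loaded IFRT executable, and rebuild
      # the output pytree from the result handlers.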
      args_leaves = tree_util.tree_leaves((args, kwargs))
      execute_result = loaded_executable.execute_sharded(
          args_leaves, with_tokens=False
      )
      results = execute_result.consume_with_handlers(out_handlers)
      return tree_util.tree_unflatten(out_specs_treedef, results)

    return call
  except jax.errors.JaxRuntimeError as e:
    # TODO(hyeontaek): Implement colocated Python support in McJAX and remove
    # this fallback path.
    if "PjRtCompiler requires an HloProgram" in str(e):
      return _deserialize(pickled_function)[0]
    raise


def _make_output_specs_and_push_result_fun(
    info: FunctionInfo,
    specialization: Specialization,
    uid: int,
) -> Callable[..., Any]:
  """Creates a function that computes output specs and pushes the result to the result store."""
  assert specialization.in_specs_treedef is not None
  assert specialization.in_specs_leaves is not None
  assert specialization.out_specs_treedef is None
  assert specialization.out_specs_leaves is None
  assert specialization.devices is not None

  devices = specialization.devices

  def lowered_fun(*args, **kwargs) -> jax.Array:
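    # Runs on the backend: execute the user function, stash the concrete
    # result leaves in the result store under `uid`, and return only the
    # serialized output specs.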
    result = info.fun(*args, **kwargs)
    result_leaves, out_treedef = tree_util.tree_flatten(result)
    out_spec_leaves = tuple(_get_spec(x) for x in result_leaves)
    func_backend.SINGLETON_RESULT_STORE.push(uid, result_leaves)
    return _serialize_specs(out_treedef, out_spec_leaves, devices)

  out_specs_leaves, out_specs_treedef = tree_util.tree_flatten(
      _make_specs_for_serialized_specs(specialization.devices),
  )
  name = getattr(info.fun, "__name__", "unknown")
  name = f"{name}_output_specs_and_push_result"
  return _compile_to_executable(
      name=name,
      fun=lowered_fun,
      in_specs_treedef=specialization.in_specs_treedef,
      in_specs_leaves=specialization.in_specs_leaves,
      out_specs_treedef=out_specs_treedef,
      out_specs_leaves=tuple(out_specs_leaves),
      devices=specialization.devices,
  )


def _make_pop_result_fun(
    info: FunctionInfo,
    specialization: Specialization,
    uid: int,
) -> Callable[..., Any]:
  """Makes a function that pops results from the result store."""
  assert specialization.out_specs_treedef is not None
  assert specialization.out_specs_leaves is not None
  assert specialization.devices is not None

  out_specs_treedef = specialization.out_specs_treedef

  def lowered_fun():
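    # Runs on the backend: retrieve the result leaves previously pushed under
    # `uid` and rebuild the output pytree.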
    result_leaves = func_backend.SINGLETON_RESULT_STORE.pop(uid)
    return tree_util.tree_unflatten(out_specs_treedef, result_leaves)

  in_specs_leaves, in_specs_treedef = tree_util.tree_flatten((
      # args
      (),
      # kwargs
      {},
  ))
  name = getattr(info.fun, "__name__", "unknown")
  name = f"{name}_pop_result"
  return _compile_to_executable(
      name=name,
      fun=lowered_fun,
      in_specs_treedef=in_specs_treedef,
      in_specs_leaves=tuple(in_specs_leaves),
      out_specs_treedef=specialization.out_specs_treedef,
      out_specs_leaves=specialization.out_specs_leaves,
      devices=specialization.devices,
  )


def _make_async_execution_fun(
    info: FunctionInfo,
    specialization: Specialization,
) -> Callable[..., Any]:
  """Makes a function that asynchronously executes the function."""
  assert specialization.in_specs_treedef is not None
  assert specialization.in_specs_leaves is not None
  assert specialization.out_specs_treedef is not None
  assert specialization.out_specs_leaves is not None
  assert specialization.devices is not None

  name = getattr(info.fun, "__name__", "unknown")
  return _compile_to_executable(
      name=name,
      fun=info.fun,
      in_specs_treedef=specialization.in_specs_treedef,
      in_specs_leaves=specialization.in_specs_leaves,
      out_specs_treedef=specialization.out_specs_treedef,
      out_specs_leaves=specialization.out_specs_leaves,
      devices=specialization.devices,
  )


def _uncached_get_specialized_func(
    info: FunctionInfo,
    specialization: Specialization,
) -> Callable[..., Any]:
  """Returns a specialized function for the given specialization."""
  util.test_event("colocated_python_func._get_specialized_func")
  assert specialization.in_specs_treedef is not None
  assert specialization.in_specs_leaves is not None
  assert specialization.devices is not None
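  # Unique ID that keys the result pushed to (and later popped from) the
  # backend result store during the first, output-spec-discovering call.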
  uid = random.getrandbits(63)

  mutex = threading.Lock()
  # Asynchronous execution function that has known output_specs.
  async_execution_func = None

  def specialized_func(*args, **kwargs):
    """Specialized function to be executed with given args and kwargs."""
    nonlocal specialization, async_execution_func
    with mutex:
      if async_execution_func is None:
        if specialization.out_specs_treedef is None:
          if specialization.out_specs_fn is None:
            output_specs_and_push_result_fun = (
                _make_output_specs_and_push_result_fun(
                    info, specialization, uid
                )
            )
            serialized_out_specs = output_specs_and_push_result_fun(
                *args, **kwargs
            )

            # Waits for the output_specs. This may block.
            out_specs_treedef, out_specs_leaves = _deserialize_specs(
                serialized_out_specs
            )

            # Subsequent calls will use async_execution_func with the
            # discovered output_specs.
            specialization = specialization.update(
                out_specs_treedef=out_specs_treedef,
                out_specs_leaves=out_specs_leaves,
            )
            async_execution_func = _make_async_execution_fun(
                info, specialization
            )

            # Hold the PyExecutable until `async_execution_func` has been
            # called at least once, so that the number of _OBJECT_STORE
            # references at the backend does not drop to 0.
            async_execution_func.output_specs_and_push_result_fun = (
                output_specs_and_push_result_fun
            )

            return _make_pop_result_fun(info, specialization, uid)()
          else:
            # Compute out_specs using out_specs_fn and inputs.
            args_specs, kwargs_specs = tree_util.tree_map(
                _get_spec, (args, kwargs)
            )
            out_specs = specialization.out_specs_fn(*args_specs, **kwargs_specs)
            # Type checking is ignored to silence mypy error: Incompatible types
            # in assignment (expression has type "list[Any]", variable has type
            # "tuple[ShapeDtypeStruct, ...]")  [assignment]
            out_specs_leaves, out_specs_treedef = tree_util.tree_flatten(  # type: ignore[assignment]
                out_specs
            )
            specialization = specialization.update(
                out_specs_treedef=out_specs_treedef,
                out_specs_leaves=tuple(out_specs_leaves),
            )
            async_execution_func = _make_async_execution_fun(
                info, specialization
            )
            # Fall-through.
        else:
          async_execution_func = _make_async_execution_fun(info, specialization)
          # Fall-through.

    # Asynchronous execution runs outside of the mutex to allow concurrent
    # execution for inline executors.
    result = async_execution_func(*args, **kwargs)
    with mutex:
      async_execution_func.output_specs_and_push_result_fun = None
    return result

  return specialized_func


class _CachedGetSpecializedFunction:
  """Manages cached versions of `_uncached_get_specialized_func`.

  This class holds a collection of caches, each identified by a unique ID, and
  presents itself as a single cache to JAX's `register_backend_cache`. One can
  clear individual caches identified by the UID, using the `cache_remove(uid)`
  method. JAX's `clear_backend_cache()` will clear all caches.
  """

  def __init__(self):
    self._lock = threading.Lock()
    self._caches: dict[int, Any] = {}
    jax_register_backend_cache(self, "colocated_python_specialized_func_cache")

  def cache_clear(self):
    self._caches.clear()

  def cache_remove(self, held_by: int):
    self._caches.pop(held_by, None)

  def get(self, held_by: int) -> Callable[..., Any]:
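    """Returns the cache for `held_by`, creating it if needed."""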
    with self._lock:
      try:
        return self._caches[held_by]
      except KeyError:
        cache = functools.cache(_uncached_get_specialized_func)
        self._caches[held_by] = cache
        return cache


_SINGLETON_CACHED_GET_SPECIALIZED_FUNCTION = _CachedGetSpecializedFunction()


class _CachedColocatedFunctionMaker:
  """Function maker for colocated Python functions.

  Generated functions are cached so that they can be reused until the cache is
  dropped.
  """

  def __init__(self, held_by: int | None):
    self.held_by = held_by
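    # A maker without a UID owns a private unbounded cache; a maker with a
    # UID shares the per-UID cache managed by the process-wide singleton so
    # that it can be cleared through JAX's backend cache mechanism.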
    if held_by is None:
      self._get_specialized_func = util.cache(
          max_size=None, trace_context_in_key=False
      )(_uncached_get_specialized_func)
    else:
      self._get_specialized_func = (
          _SINGLETON_CACHED_GET_SPECIALIZED_FUNCTION.get(held_by)
      )

  def __del__(self):
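    # Drop this maker's per-UID cache from the singleton so that cached
    # specialized functions can be released.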
    if self.held_by is not None:
      _SINGLETON_CACHED_GET_SPECIALIZED_FUNCTION.cache_remove(self.held_by)

  def _make_callable(
      self,
      info: FunctionInfo,
      specialization: Specialization,
  ):
    """Internal implementation of make_callable."""

    def specialize(
        in_specs: ShapeDtypeStructTree | None = None,
        out_specs_fn: Callable[..., ShapeDtypeStructTree] | None = None,
        devices: Sequence[jax.Device] | None = None,
    ):
      """Returns a colocated Python callable with extra specialization.

      Args:
        in_specs: Optionally specifies the expected input specs. Input specs are
          expressed as a `PyTree[ShapeDtypeStruct]` for `(args, kwargs)` of a
          function call.
        out_specs_fn: Optionally specifies a function that computes the output
          specs from input specs. If unspecified, colocated Python will compute
          the output specs during the very first execution, and this execution
          will be synchronous.
        devices: Optionally specifies the devices to execute the function on.
          Must be provided if `in_specs` has no leaves because devices cannot be
          inferred from input specs or arguments.

      Returns:
        A colocated Python callable with extra specialization.
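
      Example (an illustrative sketch; ``wrapped`` is assumed to be a callable
      returned by ``make_callable``, and ``jnp``, ``sharding``, and
      ``device_list`` are placeholders defined by the caller)::

        # Placeholders: `jnp`, `sharding`, and `device_list`.
        in_specs = (
            # args
            (jax.ShapeDtypeStruct((8,), jnp.float32, sharding=sharding),),
            # kwargs
            {},
        )
        specialized = wrapped.specialize(
            in_specs=in_specs, devices=device_list
        )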
      """
      # TODO(hyeontaek): Allow unspecified devices for zero-leaf `in_specs` if
      # `out_specs_fn(in_specs)` returns at least one leaf that we can use for
      # inferring `devices`.
      if in_specs is None:
        in_specs_leaves, in_specs_treedef = None, None
      else:
        in_specs_leaves_list, in_specs_treedef = tree_util.tree_flatten(
            in_specs
        )
        in_specs_leaves = tuple(in_specs_leaves_list)
      return self._make_callable(
          info,
          specialization.update(
              in_specs_treedef=in_specs_treedef,
              in_specs_leaves=in_specs_leaves,
              out_specs_fn=out_specs_fn,
              devices=devices,
          ),
      )

    @api_boundary
    def __call__(*args, **kwargs):
      """Executes the given Python function on the same devices as the arguments or as specialized.

      If the callable has not been specialized with output shapes and shardings
      (see `specialize` above), the very first call will run synchronously to
      discover output shapes and shardings, and will run asynchronously after.
      If
      specialized with output shapes and shardings, every execution of the
      callable will be asynchronous.
      """
      args_leaves, in_specs_treedef = tree_util.tree_flatten((args, kwargs))

      in_specs_leaves = tuple(_get_spec(x) for x in args_leaves)
      if specialization.in_specs_treedef is None:
        # Allow input polymorphism by applying the in_specs specialization
        # temporarily for this call.
        return self._make_callable(
            info,
            specialization.update(
                in_specs_treedef=in_specs_treedef,
                in_specs_leaves=in_specs_leaves,
            ),
        )(*args, **kwargs)

      if specialization.devices is None:
        devices = _infer_devices_from_args(args_leaves)
        if devices is None:
          raise ValueError(
              "No devices found. colocated_python function without input"
              " arguments must be first specialized with devices."
          )
        # Allow device polymorphism by applying devices specialization temporarily
        # for this call.
        return self._make_callable(
            info,
            specialization.update(devices=devices),
        )(*args, **kwargs)

      # Assertion is added to silence mypy error: Unsupported operand types for !=
      # ("PyTreeDef" and "None")  [operator]
      assert isinstance(specialization.in_specs_treedef, tree_util.PyTreeDef)

      # If in_specs is known, verify that it matches the actual inputs.
      if (
          specialization.in_specs_treedef != in_specs_treedef
          or specialization.in_specs_leaves != in_specs_leaves
      ):
        raise ValueError(
            "Input specs in specialization and input specs of arguments must"
            " have the same pytree structure, but they have the following"
            " structural differences:\n"
            + (
                "\n".join(
                    f"   - {tree_util.keystr(path)} is a {thing1} in value 1"
                    f" and a {thing2} in  value 2, so {explanation}.\n"
                    for path, thing1, thing2, explanation in tree_util.equality_errors_pytreedef(
                        specialization.in_specs_treedef, in_specs_treedef
                    )
                )
            )
        )

      return self._get_specialized_func(info, specialization)(*args, **kwargs)

    __call__ = wraps(info.fun)(__call__)
    __call__.specialize = specialize
    return __call__

  def make_callable(
      self,
      fun: Callable[..., Any],
      fun_sourceinfo: str | None,
      fun_signature: inspect.Signature | None,
  ):
    """Makes a colocated Python callable."""
    return self._make_callable(
        FunctionInfo(fun, fun_sourceinfo, fun_signature), Specialization()
    )


_DEFAULT_FUNCTION_MAKER = _CachedColocatedFunctionMaker(None)


def make_callable(
    fun: Callable[..., Any],
    fun_sourceinfo: str | None,
    fun_signature: inspect.Signature | None,
):
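  """Makes a colocated Python callable using the default function maker."""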
  return _DEFAULT_FUNCTION_MAKER.make_callable(
      fun, fun_sourceinfo, fun_signature
  )
